import torch
from torch import nn
import torch.nn.functional as F

import numpy as np
import math
import scipy
import scipy.signal

from mmgan_styleganv2ada import bias_act, upfirdn2d
from mmgan_styleganv2ada import bias_act2ncnn, upfirdn2d2ncnn, normalize_2nd_moment2ncnn
import ncnn_utils as ncnn_utils


def modulated_conv2d(
    x,                  # Input tensor: [batch_size, in_channels, in_height, in_width]
    w,                  # Weight tensor: [out_channels, in_channels, kernel_height, kernel_width]
    s,                  # Style tensor: [batch_size, in_channels]
    demodulate  = True, # Apply weight demodulation?
    padding     = 0,    # Padding: int or [padH, padW]
    input_gain  = None, # Optional scale factors for the input channels: [], [in_channels], or [batch_size, in_channels]
):
    """Style-modulated 2D convolution, executed as a single grouped convolution.

    Every sample gets its own copy of ``w`` scaled per input channel by its
    style vector. When ``demodulate`` is set, the weights and styles are first
    pre-normalized, and each per-sample output filter is rescaled to unit L2
    norm. Optionally the weights are also scaled by ``input_gain``. All
    per-sample kernels are then applied at once using ``groups=batch_size``.

    Returns:
        Tensor of shape [batch_size, out_channels, out_height, out_width].
    """
    n_batch = x.shape[0]
    _, c_in, k_h, k_w = w.shape

    if demodulate:
        # Pre-normalize: weight RMS per output channel, style RMS over the whole tensor.
        w = w * w.square().mean([1, 2, 3], keepdim=True).rsqrt()
        s = s * s.square().mean().rsqrt()

    # Per-sample modulated weights: [N, O, I, kh, kw].
    weights = w.unsqueeze(0) * s[:, None, :, None, None]

    if demodulate:
        inv_norm = (weights.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()  # [N, O]
        weights = weights * inv_norm[:, :, None, None, None]

    if input_gain is not None:
        gain = input_gain.expand(n_batch, c_in)  # [N, I]
        weights = weights * gain[:, None, :, None, None]

    # Fold the batch into channels so one grouped conv handles every sample.
    stacked_x = x.reshape(1, -1, *x.shape[2:])
    stacked_w = weights.reshape(-1, c_in, k_h, k_w).to(x.dtype)
    out = F.conv2d(input=stacked_x, weight=stacked_w, padding=padding, groups=n_batch)
    return out.reshape(n_batch, -1, *out.shape[2:])


def modulated_conv2d2ncnn(
    ncnn_data,
    bottom_names,
    use_fp16,
    w_shape,
    demodulate  = True, # Apply weight demodulation?
    padding     = 0,    # Padding: int or [padH, padW]
    input_gain  = None, # Optional scale factors for the input channels: [], [in_channels], or [batch_size, in_channels]
):
    """Emit ncnn layers that mirror :func:`modulated_conv2d`.

    ``bottom_names`` is ``[x_blob, weight_blob, style_blob]``. ``w_shape`` is the
    static ``[out_channels, in_channels, kh, kw]`` shape of the weight. Layers are
    appended to ``ncnn_data`` in the order the ``ncnn_utils`` calls below are made.

    Unlike the PyTorch version, the weight path has no batch dimension and the
    final convolution uses ``groups=1``. This presumably assumes a single sample
    per inference (TODO confirm).

    ``use_fp16`` is accepted but not used in this function.

    Returns:
        The result of ``ncnn_utils.Fconv2d`` (a list of blob names; callers index it with ``[0]``).
    """
    x, w, s = bottom_names
    x = [x, ]
    w = [w, ]
    s = [s, ]
    out_channels, in_channels, kh, kw = w_shape

    # Pre-normalize inputs.
    if demodulate:
        # w = w * rsqrt(mean(w**2 over I, kh, kw)), broadcast along the output-channel dim.
        temp = ncnn_utils.square(ncnn_data, w)
        temp = ncnn_utils.really_reduction(ncnn_data, temp, op='ReduceMean', dims=(1, 2, 3), keepdim=True)
        temp = ncnn_utils.rsqrt(ncnn_data, temp, eps=0.0, scale=None)
        temp = ncnn_utils.really_reshape(ncnn_data, temp, shape=(1, 1, 1, -1))
        w = ncnn_utils.F4DOp1D(ncnn_data, [w[0], temp[0]], dim=0, op='Mul')
        # s = s * rsqrt(mean(s**2)).
        temp2 = ncnn_utils.square(ncnn_data, s)
        temp2 = ncnn_utils.really_reduction(ncnn_data, temp2, op='ReduceMean', dims=(0, ), keepdim=True)
        temp2 = ncnn_utils.rsqrt(ncnn_data, temp2, eps=0.0, scale=None)
        s = ncnn_utils.binaryOp(ncnn_data, [s[0], temp2[0]], op='Mul')

    # Modulate weights.
    w = ncnn_utils.F4DOp1D(ncnn_data, [w[0], s[0]], dim=1, op='Mul')

    # Demodulate weights.
    if demodulate:
        dcoefs = ncnn_utils.square(ncnn_data, w)
        # Do not use ReduceSum: the sum can exceed 65504.0, the largest value representable
        # in half precision, which makes the result inaccurate. Instead, reduce with
        # ReduceMean and pass `scale` (= in_channels * kh * kw) to rsqrt to recover the
        # ReduceSum result (presumably rsqrt computes 1/sqrt(x * scale + eps) — TODO confirm).
        dcoefs = ncnn_utils.really_reduction(ncnn_data, dcoefs, op='ReduceMean', dims=(1, 2, 3), keepdim=True)
        scale = float(in_channels * kh * kw)
        dcoefs = ncnn_utils.rsqrt(ncnn_data, dcoefs, eps=1e-8, scale=scale)
        dcoefs = ncnn_utils.really_reshape(ncnn_data, dcoefs, shape=(1, 1, 1, -1))
        w = ncnn_utils.F4DOp1D(ncnn_data, [w[0], dcoefs[0]], dim=0, op='Mul')  # [OIkk]

    # Apply input scaling.
    if input_gain is not None:
        # The gain is baked into the graph as a constant taken from the PyTorch tensor.
        input_gain_ = input_gain.cpu().detach().numpy()
        w = ncnn_utils.MulConstant(ncnn_data, w, scale=input_gain_, bias=0.0)

    # Execute the convolution with the modulated weights as a runtime blob (groups=1).
    x = ncnn_utils.Fconv2d(ncnn_data, [x[0], w[0]], padding=padding, dilation=1, groups=1)
    return x



def filter2d(x, f, padding=0, flip_filter=False, gain=1):
    """Filter ``x`` with the FIR filter ``f`` via :func:`upfirdn2d`.

    ``padding`` accepts the forms handled by ``_parse_padding``. On top of it, the
    filter's footprint is compensated with extra padding: ``size // 2`` before
    and ``(size - 1) // 2`` after, along each axis. Presumably this keeps the
    spatial size unchanged for zero ``padding``, given upfirdn2d's
    valid-convolution semantics (TODO confirm).
    """
    left, right, top, bottom = _parse_padding(padding)
    f_w, f_h = _get_filter_size(f)
    total_padding = [
        left + f_w // 2,
        right + (f_w - 1) // 2,
        top + f_h // 2,
        bottom + (f_h - 1) // 2,
    ]
    return upfirdn2d(x, f, padding=total_padding, flip_filter=flip_filter, gain=gain)


class FullyConnectedLayer(nn.Module):
    """Linear layer with run-time weight/bias scaling and an optional activation.

    The parameters are stored divided by ``lr_multiplier``. At run time the
    weight is multiplied by ``weight_gain = lr_multiplier / sqrt(in_features)``
    and the bias by ``bias_gain = lr_multiplier``.
    """

    def __init__(self,
        in_features,                # Number of input features.
        out_features,               # Number of output features.
        activation      = 'linear', # Activation function: 'relu', 'lrelu', etc.
        bias            = True,     # Apply additive bias before the activation function?
        lr_multiplier   = 1,        # Learning rate multiplier.
        weight_init     = 1,        # Initial standard deviation of the weight tensor.
        bias_init       = 0,        # Initial value of the additive bias.
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.activation = activation
        init_std = weight_init / lr_multiplier
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * init_std)
        # Broadcast the bias initializer to [out_features] (also validates its shape).
        bias_values = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features])
        if bias:
            self.bias = torch.nn.Parameter(torch.from_numpy(bias_values / lr_multiplier))
        else:
            self.bias = None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def _scaled_bias(self, dtype):
        """Return the bias cast to ``dtype`` and multiplied by ``bias_gain``, or None."""
        if self.bias is None:
            return None
        scaled = self.bias.to(dtype)
        return scaled if self.bias_gain == 1 else scaled * self.bias_gain

    def forward(self, x):
        """Apply the layer to ``x`` of shape [batch, in_features]."""
        weight = self.weight.to(x.dtype) * self.weight_gain
        bias = self._scaled_bias(x.dtype)
        if bias is not None and self.activation == 'linear':
            # Fused bias + matmul fast path.
            return torch.addmm(bias.unsqueeze(0), x, weight.t())
        return bias_act(x.matmul(weight.t()), bias, act=self.activation)

    def export_ncnn(self, ncnn_data, bottom_names):
        """Emit the ncnn layers for this FC layer and return the output blob names."""
        weight = self.weight.to(torch.float32) * self.weight_gain
        bias = self._scaled_bias(torch.float32)

        # Emit the weight (as [out_C, in_C, 1, 1]) and bias blobs.
        blobs = ncnn_utils.shell(ncnn_data, bottom_names, weight.unsqueeze(2).unsqueeze(3), bias)
        w_name, b_name = blobs[0], blobs[1]

        # Fmatmul consumes w in its [out_C, in_C, 1, 1] layout, so no transpose
        # (really_permute to [in_C, out_C, 1, 1]) is needed here.
        if bias is not None and self.activation == 'linear':
            return ncnn_utils.Fmatmul(ncnn_data, [bottom_names[0], w_name, b_name], weight.shape)
        product = ncnn_utils.Fmatmul(ncnn_data, [bottom_names[0], w_name], weight.shape)
        return bias_act2ncnn(ncnn_data, [product[0], b_name], act=self.activation)


class StyleGANv3_MappingNetwork(nn.Module):
    """Map latent ``z`` (and optionally a label ``c``) to ``num_ws`` copies of ``w``.

    Tracks a moving average ``w_avg`` of the produced ``w``. That average is the
    anchor for the truncation trick: ``lerp(w_avg, w, truncation_psi)``.
    """

    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality.
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no labels.
        w_dim,                      # Intermediate latent (W) dimensionality.
        num_ws,                     # Number of intermediate latents to output.
        num_layers      = 2,        # Number of mapping layers.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers.
        w_avg_beta      = 0.998,    # Decay for tracking the moving average of W during training.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        has_labels = self.c_dim > 0
        # Optional label embedding (created before the fc stack to keep registration order).
        self.embed = FullyConnectedLayer(self.c_dim, self.w_dim) if has_labels else None
        # Layer widths: the first layer sees z concatenated with the embedded label (if any).
        widths = [self.z_dim + (self.w_dim if has_labels else 0)] + [self.w_dim] * self.num_layers
        for idx in range(num_layers):
            fc = FullyConnectedLayer(widths[idx], widths[idx + 1], activation='lrelu', lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', fc)
        self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
        """Return ws of shape [batch, num_ws, w_dim] for latents ``z`` and labels ``c``."""
        cutoff = self.num_ws if truncation_cutoff is None else truncation_cutoff

        def _normalize(t):
            # Normalize each row to unit second moment.
            return t * (t.square().mean(1, keepdim=True) + 1e-8).rsqrt()

        x = _normalize(z.to(torch.float32))
        if self.c_dim > 0:
            label = _normalize(self.embed(c.to(torch.float32)))
            x = torch.cat([x, label], dim=1)

        for idx in range(self.num_layers):
            x = getattr(self, f'fc{idx}')(x)

        if update_emas:
            # w_avg <- lerp(batch mean of x, w_avg, beta)
            self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        ws = x.unsqueeze(1).repeat([1, self.num_ws, 1])
        if truncation_psi != 1:
            ws[:, :cutoff] = self.w_avg.lerp(ws[:, :cutoff], truncation_psi)
        return ws

    def export_ncnn(self, ncnn_data, bottom_names):
        """Emit the ncnn mapping graph; ``bottom_names`` is ``[z_blob, coeff_blob]``.

        The final ``ncnn_utils.lerp`` takes ``w_avg``, the mapped ``x`` and the
        runtime ``coeff`` blob. It presumably implements the truncation blend
        (TODO confirm). Label conditioning is not supported.
        """
        z, coeff = bottom_names
        x = normalize_2nd_moment2ncnn(ncnn_data, [z, ]) if self.z_dim > 0 else None
        if self.c_dim > 0:
            raise NotImplementedError("not implemented.")

        # Bake w_avg as a 1-D weight blob; tensor shape [1, 1, 1, W].
        w_avg_names = ncnn_utils.shell(ncnn_data, x, self.w_avg.unsqueeze(0).unsqueeze(0).unsqueeze(0), None, ncnn_weight_dims=1)

        for idx in range(self.num_layers):
            x = getattr(self, f'fc{idx}').export_ncnn(ncnn_data, x)

        return ncnn_utils.lerp(ncnn_data, w_avg_names + x + [coeff, ])


class SynthesisInput(nn.Module):
    """Fourier-feature input layer of the StyleGAN3 synthesis network.

    The output is a grid of sinusoidal features. The frequencies and phases are
    fixed buffers. Per sample, they are rotated and translated by a transform
    that ``affine`` predicts from ``w``, and then mixed by a trainable
    channel-to-channel ``weight``.
    """
    def __init__(self,
        w_dim,          # Intermediate latent (W) dimensionality.
        channels,       # Number of output channels.
        size,           # Output spatial size: int or [width, height].
        sampling_rate,  # Output sampling rate.
        bandwidth,      # Output bandwidth.
    ):
        super().__init__()
        self.w_dim = w_dim
        self.channels = channels
        self.size = np.broadcast_to(np.asarray(size), [2])
        self.sampling_rate = sampling_rate
        self.bandwidth = bandwidth

        # Draw random frequencies from uniform 2D disc.
        freqs = torch.randn([self.channels, 2])
        radii = freqs.square().sum(dim=1, keepdim=True).sqrt()
        freqs /= radii * radii.square().exp().pow(0.25)
        freqs *= bandwidth
        phases = torch.rand([self.channels]) - 0.5

        # Setup parameters and buffers.
        self.weight = torch.nn.Parameter(torch.randn([self.channels, self.channels]))
        # Predicts t = (r_c, r_s, t_x, t_y). Zero weight + bias [1, 0, 0, 0] start it at the identity transform.
        self.affine = FullyConnectedLayer(w_dim, 4, weight_init=0, bias_init=[1,0,0,0])
        self.register_buffer('transform', torch.eye(3, 3)) # User-specified inverse transform wrt. resulting image.
        self.register_buffer('freqs', freqs)
        self.register_buffer('phases', phases)

    def forward(self, w):
        """Compute Fourier features for the latent batch ``w``.

        Returns:
            Tensor of shape [batch, channels, height, width], where
            ``(width, height) == tuple(self.size)``.
        """
        # Introduce batch dimension.
        transforms = self.transform.unsqueeze(0) # [batch, row, col]
        freqs = self.freqs.unsqueeze(0) # [batch, channel, xy]
        phases = self.phases.unsqueeze(0) # [batch, channel]

        # Apply learned transformation.
        t = self.affine(w) # t = (r_c, r_s, t_x, t_y)
        # Divide by ||(r_c, r_s)|| so the rotation part is a pure rotation.
        t = t / t[:, :2].norm(dim=1, keepdim=True) # t' = (r'_c, r'_s, t'_x, t'_y)
        m_r = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse rotation wrt. resulting image.
        m_r[:, 0, 0] = t[:, 0]  # r'_c
        m_r[:, 0, 1] = -t[:, 1] # r'_s
        m_r[:, 1, 0] = t[:, 1]  # r'_s
        m_r[:, 1, 1] = t[:, 0]  # r'_c
        m_t = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse translation wrt. resulting image.
        m_t[:, 0, 2] = -t[:, 2] # t'_x
        m_t[:, 1, 2] = -t[:, 3] # t'_y
        transforms = m_r @ m_t @ transforms # First rotate resulting image, then translate, and finally apply user-specified transform.

        # Transform frequencies.
        phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2)
        freqs = freqs @ transforms[:, :2, :2]

        # Dampen out-of-band frequencies that may occur due to the user-specified transform.
        amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1)

        # Construct sampling grid.
        # theta scales affine_grid's normalized coordinates by 0.5 * size / sampling_rate per axis.
        theta = torch.eye(2, 3, device=w.device)
        theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate
        theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate
        grids = torch.nn.functional.affine_grid(theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]], align_corners=False)

        # Compute Fourier features.
        x = (grids.unsqueeze(3) @ freqs.permute(0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(3) # [batch, height, width, channel]
        x = x + phases.unsqueeze(1).unsqueeze(2)
        x = torch.sin(x * (np.pi * 2))
        x = x * amplitudes.unsqueeze(1).unsqueeze(2)

        # Apply trainable mapping.
        weight = self.weight / np.sqrt(self.channels)
        x = x @ weight.t()

        # Ensure correct shape.
        x = x.permute(0, 3, 1, 2) # [batch, channel, height, width]
        return x

    def export_ncnn(self, ncnn_data, bottom_names, ws_i):
        """Emit ncnn layers that mirror :meth:`forward` for a single sample.

        ``bottom_names`` is ``[ws0, ws1, mixing]``. ``ncnn_utils.StyleMixingSwitcher``
        picks the latent for index ``ws_i``. It presumably chooses between ``ws0``
        and ``ws1`` according to ``mixing`` (TODO confirm). Returns the blob names
        of the final permute, laid out as [channels, height, width].

        NOTE(review): ``self.transform`` is only used for its device here. Its
        values are never passed to ``ncnn_utils.Transforms``, so a non-identity
        user transform would not be reflected in the exported graph. Confirm
        this is intended.
        """
        ws0, ws1, mixing = bottom_names
        w = ncnn_utils.StyleMixingSwitcher(ncnn_data, [ws0, ws1, mixing], ws_i=ws_i)
        w = w[0]

        # Bake the frequency and phase buffers into the graph as constant weight blobs.
        freqs_name = ncnn_utils.shell(ncnn_data, [ws0, ], self.freqs.unsqueeze(0).unsqueeze(0), None, ncnn_weight_dims=2)
        phases_name = ncnn_utils.shell(ncnn_data, [ws0, ], self.phases.unsqueeze(0).unsqueeze(0).unsqueeze(0), None, ncnn_weight_dims=1)

        # Apply learned transformation.
        t = self.affine.export_ncnn(ncnn_data, [w, ])

        # 1 / ||t[0:2]||: crop (r_c, r_s), sum of squares, then rsqrt with no epsilon
        # (forward also divides by the plain norm).
        temp_t = ncnn_utils.crop(ncnn_data, t, starts='1,%d' % (0, ), ends='1,%d' % (2, ), axes='1,0')
        temp_t = ncnn_utils.reduction(ncnn_data, temp_t, op='ReduceSumSquare', input_dims=1, dims=(0, ), keepdim=True)

        temp_t = ncnn_utils.rsqrt(ncnn_data, temp_t, eps=0.)

        # Finally, element-wise multiply: t' = t / ||t[0:2]||.
        t = [t[0], temp_t[0]]
        t = ncnn_utils.binaryOp(ncnn_data, t, op='Mul')

        # Presumably builds the 3x3 inverse transform m_r @ m_t from t' (TODO confirm).
        transforms = ncnn_utils.Transforms(ncnn_data, t)
        # phases = freqs @ transforms[:2, 2:] + phases   (mirrors forward)
        transforms_0 = ncnn_utils.crop(ncnn_data, transforms, starts='1,%d' % (2, ), ends='1,%d' % (3, ), axes='1,1')
        transforms_0 = ncnn_utils.really_permute(ncnn_data, transforms_0, perm=(1, 0))
        phases = ncnn_utils.Fmatmul(ncnn_data, [freqs_name[0], transforms_0[0]], (1, 2))
        phases = ncnn_utils.really_reshape(ncnn_data, phases, (self.channels, ))
        phases = ncnn_utils.binaryOp(ncnn_data, [phases[0], phases_name[0]], op='Add')

        # freqs = freqs @ transforms[:2, :2]   (mirrors forward)
        transforms_1 = ncnn_utils.crop(ncnn_data, transforms, starts='1,%d' % (0, ), ends='1,%d' % (2, ), axes='1,1')
        transforms_1 = ncnn_utils.really_permute(ncnn_data, transforms_1, perm=(1, 0))
        freqs = ncnn_utils.Fmatmul(ncnn_data, [freqs_name[0], transforms_1[0]], (2, 2))

        # amplitudes = clamp(1 - (||f|| - bw) / B, 0, 1), rewritten as ||f|| * (-1/B) + (1 + bw/B).
        amplitudes = ncnn_utils.reduction(ncnn_data, freqs, op='ReduceSumSquare', input_dims=2, dims=(2,), keepdim=False)
        amplitudes = ncnn_utils.sqrt(ncnn_data, amplitudes)
        B = (self.sampling_rate / 2 - self.bandwidth)
        amplitudes = ncnn_utils.MulConstant(ncnn_data, amplitudes, scale=-1.0/B, bias=1.0 + self.bandwidth / B)
        amplitudes = ncnn_utils.clamp(ncnn_data, amplitudes, min_v=0.0, max_v=1.0)

        # The sampling grid does not depend on w: compute it in PyTorch and bake it as a constant.
        # grids has shape [1, H, W, 2], so d0 == 1.
        theta = torch.eye(2, 3, device=self.transform.device)
        theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate
        theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate
        grids = torch.nn.functional.affine_grid(theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]], align_corners=False)
        d0, d1, d2, d3 = grids.shape
        grids_name = ncnn_utils.shell(ncnn_data, [ws0, ], grids.reshape((1, 1, d0*d1*d2, d3)), None, ncnn_weight_dims=2)

        x = ncnn_utils.Fmatmul(ncnn_data, [grids_name[0], freqs[0]], (self.channels, 2))
        x = ncnn_utils.really_reshape(ncnn_data, x, (d0, d1, d2, self.channels))

        x = ncnn_utils.F4DOp1D(ncnn_data, [x[0], phases[0]], dim=3, op='Add')
        # Presumably sin(x * pi * scale), matching torch.sin(x * 2*pi) in forward (TODO confirm).
        x = ncnn_utils.sin(ncnn_data, x, scale=2.0)

        x = ncnn_utils.F4DOp1D(ncnn_data, [x[0], amplitudes[0]], dim=3, op='Mul')


        # Apply trainable mapping.
        weight_ = self.weight / np.sqrt(self.channels)
        weight_name = ncnn_utils.shell(ncnn_data, [ws0, ], weight_.unsqueeze(0).unsqueeze(0), None, ncnn_weight_dims=2)

        x = ncnn_utils.really_reshape(ncnn_data, x, (d0*d1*d2, self.channels))
        x = ncnn_utils.Fmatmul(ncnn_data, [x[0], weight_name[0]], (self.channels, self.channels))
        # Drop the leading batch dim (d0 == 1) and move channels first: [C, H, W].
        x = ncnn_utils.really_reshape(ncnn_data, x, (d1, d2, self.channels))
        x = ncnn_utils.really_permute(ncnn_data, x, perm=[2, 0, 1])
        return x


def _get_filter_size(f):
    if f is None:
        return 1, 1
    assert isinstance(f, torch.Tensor)
    assert 1 <= f.ndim <= 2
    return f.shape[-1], f.shape[0] # width, height


def _parse_padding(padding):
    if isinstance(padding, int):
        padding = [padding, padding]
    assert isinstance(padding, (list, tuple))
    assert all(isinstance(x, (int, np.integer)) for x in padding)
    padding = [int(x) for x in padding]
    if len(padding) == 2:
        px, py = padding
        padding = [px, px, py, py]
    px0, px1, py0, py1 = padding
    return px0, px1, py0, py1


def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
    """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
    existing `upfirdn2d()` and `bias_act()` ops.

    Pipeline:
      1. add bias ``b``;
      2. upsample by ``up`` with filter ``fu`` (padding ``padding``);
      3. leaky ReLU with ``slope``, ``gain`` and ``clamp`` (via ``bias_act``);
      4. filter with ``fd`` and downsample by ``down``.

    Args:
        x: Input tensor of shape [batch_size, channels, in_height, in_width].
        fu, fd: Upsampling / downsampling FIR filters (``None`` = identity).
        b: Optional bias of shape [channels], with the same dtype as ``x``.
        up, down: Integer up/downsampling factors (>= 1).
        padding: Padding applied before upsampling; see ``_parse_padding``.
        gain, slope, clamp: Leaky ReLU parameters forwarded to ``bias_act``.
        flip_filter: Forwarded to ``upfirdn2d``.

    Returns:
        Tensor of shape [batch_size, channels, out_h, out_w] with the dtype of ``x``.
    """
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    fu_w, fu_h = _get_filter_size(fu)
    fd_w, fd_h = _get_filter_size(fd)
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
        assert b.ndim == 1 and b.shape[0] == x.shape[1], 'bias must have shape [channels]'
    assert isinstance(up, int) and up >= 1
    assert isinstance(down, int) and down >= 1
    px0, px1, py0, py1 = _parse_padding(padding)
    assert gain == float(gain) and gain > 0
    assert slope == float(slope) and slope >= 0
    assert clamp is None or (clamp == float(clamp) and clamp >= 0)

    # Calculate expected output size (used to validate the result below).
    batch_size, channels, in_h, in_w = x.shape
    in_dtype = x.dtype
    out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
    out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down

    # Compute using existing ops.
    x = bias_act(x=x, b=b) # Apply bias.
    x = upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
    x = bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
    x = upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.

    # Check output shape & dtype.
    expected_shape = (batch_size, channels, out_h, out_w)
    assert tuple(x.shape) == expected_shape, f'unexpected output shape {tuple(x.shape)}, expected {expected_shape}'
    assert x.dtype == in_dtype
    return x


def filtered_lrelu2ncnn(ncnn_data, bottom_names, out_C, fu=None, fd=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
    """Emit the ncnn layers that mirror :func:`filtered_lrelu`.

    ``bottom_names`` is ``[x_blob, bias_blob]``. The layers are appended to
    ``ncnn_data`` in this order: bias add, upsample with ``fu``, leaky ReLU,
    then filter with ``fd`` and downsample. Returns the result of the final
    ``upfirdn2d2ncnn`` call.
    """
    x_name, b_name = bottom_names
    # These calls only validate the filters; their sizes are not needed here.
    _get_filter_size(fu)
    _get_filter_size(fd)
    assert isinstance(up, int) and up >= 1
    assert isinstance(down, int) and down >= 1
    pads = list(_parse_padding(padding))
    assert gain == float(gain) and gain > 0
    assert slope == float(slope) and slope >= 0
    assert clamp is None or (clamp == float(clamp) and clamp >= 0)

    out = bias_act2ncnn(ncnn_data, [x_name, b_name])  # Apply bias.
    out = upfirdn2d2ncnn(ncnn_data, out, fu, out_C, up=up, padding=pads, gain=up**2, flip_filter=flip_filter)  # Upsample.
    out = bias_act2ncnn(ncnn_data, out, act='lrelu', alpha=slope, gain=gain, clamp=clamp)  # Leaky ReLU, clamp.
    return upfirdn2d2ncnn(ncnn_data, out, fd, out_C, down=down, flip_filter=flip_filter)  # Downsample.


class SynthesisLayer(nn.Module):
    """One StyleGAN3 synthesis layer, or the final ToRGB layer when ``is_torgb``.

    It applies a style-modulated convolution and then a filtered leaky ReLU:
      1. upsample by ``up_factor`` to ``tmp_sampling_rate``;
      2. apply the nonlinearity;
      3. low-pass filter and downsample by ``down_factor`` to ``out_sampling_rate``.
    """
    def __init__(self,
        w_dim,                          # Intermediate latent (W) dimensionality.
        is_torgb,                       # Is this the final ToRGB layer?
        is_critically_sampled,          # Does this layer use critical sampling?
        use_fp16,                       # Does this layer use FP16?

        # Input & output specifications.
        in_channels,                    # Number of input channels.
        out_channels,                   # Number of output channels.
        in_size,                        # Input spatial size: int or [width, height].
        out_size,                       # Output spatial size: int or [width, height].
        in_sampling_rate,               # Input sampling rate (s).
        out_sampling_rate,              # Output sampling rate (s).
        in_cutoff,                      # Input cutoff frequency (f_c).
        out_cutoff,                     # Output cutoff frequency (f_c).
        in_half_width,                  # Input transition band half-width (f_h).
        out_half_width,                 # Output Transition band half-width (f_h).

        # Hyperparameters.
        conv_kernel         = 3,        # Convolution kernel size. Ignored for final the ToRGB layer.
        filter_size         = 6,        # Low-pass filter size relative to the lower resolution when up/downsampling.
        lrelu_upsampling    = 2,        # Relative sampling rate for leaky ReLU. Ignored for final the ToRGB layer.
        use_radial_filters  = False,    # Use radially symmetric downsampling filter? Ignored for critically sampled layers.
        conv_clamp          = 256,      # Clamp the output to [-X, +X], None = disable clamping.
        magnitude_ema_beta  = 0.999,    # Decay rate for the moving average of input magnitudes.
    ):
        super().__init__()
        self.w_dim = w_dim
        self.is_torgb = is_torgb
        self.is_critically_sampled = is_critically_sampled
        self.use_fp16 = use_fp16
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_size = np.broadcast_to(np.asarray(in_size), [2])
        self.out_size = np.broadcast_to(np.asarray(out_size), [2])
        self.in_sampling_rate = in_sampling_rate
        self.out_sampling_rate = out_sampling_rate
        # Rate at which the nonlinearity runs: the higher of the in/out rates,
        # oversampled by lrelu_upsampling (not for ToRGB).
        self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling)
        self.in_cutoff = in_cutoff
        self.out_cutoff = out_cutoff
        self.in_half_width = in_half_width
        self.out_half_width = out_half_width
        self.conv_kernel = 1 if is_torgb else conv_kernel
        self.conv_clamp = conv_clamp
        self.magnitude_ema_beta = magnitude_ema_beta

        # Setup parameters and buffers.
        self.affine = FullyConnectedLayer(self.w_dim, self.in_channels, bias_init=1)
        self.weight = torch.nn.Parameter(torch.randn([self.out_channels, self.in_channels, self.conv_kernel, self.conv_kernel]))
        self.bias = torch.nn.Parameter(torch.zeros([self.out_channels]))
        # Moving average of the input's mean square (updated in forward); its rsqrt is the input gain.
        self.register_buffer('magnitude_ema', torch.ones([]))

        # Design upsampling filter.
        # up_taps == 1 means no filter: design_lowpass_filter then returns None.
        self.up_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate))
        assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate
        self.up_taps = filter_size * self.up_factor if self.up_factor > 1 and not self.is_torgb else 1
        self.register_buffer('up_filter', self.design_lowpass_filter(
            numtaps=self.up_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate))

        # Design downsampling filter.
        self.down_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate))
        assert self.out_sampling_rate * self.down_factor == self.tmp_sampling_rate
        self.down_taps = filter_size * self.down_factor if self.down_factor > 1 and not self.is_torgb else 1
        self.down_radial = use_radial_filters and not self.is_critically_sampled
        self.register_buffer('down_filter', self.design_lowpass_filter(
            numtaps=self.down_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.down_radial))

        # Compute padding.
        pad_total = (self.out_size - 1) * self.down_factor + 1 # Desired output size before downsampling.
        pad_total -= (self.in_size + self.conv_kernel - 1) * self.up_factor # Input size after upsampling.
        pad_total += self.up_taps + self.down_taps - 2 # Size reduction caused by the filters.
        pad_lo = (pad_total + self.up_factor) // 2 # Shift sample locations according to the symmetric interpretation (Appendix C.3).
        pad_hi = pad_total - pad_lo
        # Stored as [x_lo, x_hi, y_lo, y_hi] (index 0 = width axis, index 1 = height axis).
        self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])]

    def forward(self, x, w, noise_mode='random', force_fp32=False, update_emas=False):
        """Run the layer on ``x`` with per-sample latents ``w`` ([batch, w_dim]).

        ``noise_mode`` is validated but otherwise unused. FP16 is used only when
        ``use_fp16`` is set, ``force_fp32`` is not, and ``x`` is on a CUDA device.
        When ``update_emas`` is set, ``magnitude_ema`` is updated in place.
        """
        assert noise_mode in ['random', 'const', 'none'] # unused

        # Track input magnitude.
        if update_emas:
            with torch.autograd.profiler.record_function('update_magnitude_ema'):
                magnitude_cur = x.detach().to(torch.float32).square().mean()
                self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema, self.magnitude_ema_beta))
        # Scale the input by 1 / RMS magnitude, as tracked in magnitude_ema.
        input_gain = self.magnitude_ema.rsqrt()

        # Execute affine layer.
        styles = self.affine(w)
        if self.is_torgb:
            weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel ** 2))
            styles = styles * weight_gain

        # Execute modulated conv2d.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
        x = modulated_conv2d(x=x.to(dtype), w=self.weight, s=styles,
            padding=self.conv_kernel-1, demodulate=(not self.is_torgb), input_gain=input_gain)

        # Execute bias, filtered leaky ReLU, and clamping.
        gain = 1 if self.is_torgb else np.sqrt(2)
        slope = 1 if self.is_torgb else 0.2
        x = filtered_lrelu(x=x, fu=self.up_filter, fd=self.down_filter, b=self.bias.to(x.dtype),
            up=self.up_factor, down=self.down_factor, padding=self.padding, gain=gain, slope=slope, clamp=self.conv_clamp)

        # Ensure correct shape and dtype.
        assert x.dtype == dtype
        return x

    def export_ncnn(self, ncnn_data, bottom_names, ws_i, noise_mode='random', force_fp32=False, update_emas=False):
        """Emit ncnn layers that mirror :meth:`forward`.

        ``bottom_names`` is ``[x_blob, ws0, ws1, mixing]``. ``ws_i`` selects this
        layer's latent via ``ncnn_utils.StyleMixingSwitcher``. ``noise_mode``,
        ``force_fp32`` and ``update_emas`` are accepted for signature parity with
        ``forward`` but are not used. The current ``magnitude_ema`` is baked into
        the graph as a constant input gain. Returns the result of
        ``filtered_lrelu2ncnn``.
        """
        x, ws0, ws1, mixing = bottom_names
        w = ncnn_utils.StyleMixingSwitcher(ncnn_data, [ws0, ws1, mixing], ws_i=ws_i)
        w = w[0]

        input_gain = self.magnitude_ema.rsqrt()

        # Execute affine layer.
        styles = self.affine.export_ncnn(ncnn_data, [w, ])
        if self.is_torgb:
            weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel ** 2))
            styles = ncnn_utils.MulConstant(ncnn_data, styles, scale=weight_gain)

        # Emit the conv weight and bias as blobs (w_b[0] = weight, w_b[1] = bias).
        w_b = ncnn_utils.shell(ncnn_data, w, self.weight, self.bias)

        # Execute modulated conv2d.
        # dtype = torch.float16 if (self.use_fp16 and not force_fp32) else torch.float32
        x = modulated_conv2d2ncnn(ncnn_data, [x, w_b[0], styles[0]], self.use_fp16, w_shape=self.weight.shape, padding=self.conv_kernel-1, demodulate=(not self.is_torgb), input_gain=input_gain)

        # Execute bias, filtered leaky ReLU, and clamping.
        gain = 1 if self.is_torgb else np.sqrt(2)
        slope = 1 if self.is_torgb else 0.2
        x = filtered_lrelu2ncnn(ncnn_data, [x[0], w_b[1]], self.out_channels, fu=self.up_filter, fd=self.down_filter,
            up=self.up_factor, down=self.down_factor, padding=self.padding, gain=gain, slope=slope, clamp=self.conv_clamp)

        return x

    @staticmethod
    def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False):
        """Design a low-pass FIR filter, returned as a float32 tensor.

        Returns ``None`` when ``numtaps == 1`` (identity).

        When ``radial`` is False, returns a 1-D ``scipy.signal.firwin`` filter of
        length ``numtaps``.

        When ``radial`` is True, returns a 2-D ``[numtaps, numtaps]`` jinc-based
        filter. It is windowed by the outer product of two Kaiser windows and
        normalized to sum to 1.
        """
        assert numtaps >= 1

        # Identity filter.
        if numtaps == 1:
            return None

        # Separable Kaiser low-pass filter.
        if not radial:
            f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
            return torch.as_tensor(f, dtype=torch.float32)

        # Radially symmetric jinc-based filter.
        # NOTE(review): for odd numtaps the centre tap has r == 0 (division by zero);
        # callers here pass filter_size * factor, which is even for the default filter_size=6.
        x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
        r = np.hypot(*np.meshgrid(x, x))
        f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
        beta = scipy.signal.kaiser_beta(scipy.signal.kaiser_atten(numtaps, width / (fs / 2)))
        w = np.kaiser(numtaps, beta)
        f *= np.outer(w, w)
        f /= np.sum(f)
        return torch.as_tensor(f, dtype=torch.float32)


class StyleGANv3_SynthesisNetwork(nn.Module):
    """StyleGAN3 synthesis network.

    A Fourier-feature input layer is followed by ``num_layers`` synthesis layers
    and a final ToRGB layer. The network consumes ``num_ws = num_layers + 2``
    latents: ``ws[:, 0]`` drives the input layer, and ``ws[:, 1:]`` drive the
    synthesis layers in order.
    """
    def __init__(self,
        w_dim,                          # Intermediate latent (W) dimensionality.
        img_resolution,                 # Output image resolution.
        img_channels,                   # Number of color channels.
        channel_base        = 32768,    # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_layers          = 14,       # Total number of layers, excluding Fourier features and ToRGB.
        num_critical        = 2,        # Number of critically sampled layers at the end.
        first_cutoff        = 2,        # Cutoff frequency of the first layer (f_{c,0}).
        first_stopband      = 2**2.1,   # Minimum stopband of the first layer (f_{t,0}).
        last_stopband_rel   = 2**0.3,   # Minimum stopband of the last layer, expressed relative to the cutoff.
        margin_size         = 10,       # Number of additional pixels outside the image.
        output_scale        = 0.25,     # Scale factor for the output image.
        num_fp16_res        = 4,        # Use FP16 for the N highest resolutions.
        **layer_kwargs,                 # Arguments for SynthesisLayer.
    ):
        super().__init__()
        self.w_dim = w_dim
        # One latent for the input layer plus one per synthesis layer (num_layers + ToRGB).
        self.num_ws = num_layers + 2
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.num_layers = num_layers
        self.num_critical = num_critical
        self.margin_size = margin_size
        self.output_scale = output_scale
        self.num_fp16_res = num_fp16_res

        # Geometric progression of layer cutoffs and min. stopbands.
        # The exponent saturates at 1, so the last num_critical layers share the final values.
        last_cutoff = self.img_resolution / 2 # f_{c,N}
        last_stopband = last_cutoff * last_stopband_rel # f_{t,N}
        exponents = np.minimum(np.arange(self.num_layers + 1) / (self.num_layers - self.num_critical), 1)
        cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents # f_c[i]
        stopbands = first_stopband * (last_stopband / first_stopband) ** exponents # f_t[i]

        # Compute remaining layer parameters.
        # Sampling rate: smallest power of two >= min(2 * stopband, img_resolution).
        sampling_rates = np.exp2(np.ceil(np.log2(np.minimum(stopbands * 2, self.img_resolution)))) # s[i]
        half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs # f_h[i]
        # Spatial size = sampling rate plus a margin on each side; the last two layers are exactly img_resolution.
        sizes = sampling_rates + self.margin_size * 2
        sizes[-2:] = self.img_resolution
        # Channel count ~ (channel_base / 2) / cutoff, capped at channel_max; ToRGB outputs img_channels.
        channels = np.rint(np.minimum((channel_base / 2) / cutoffs, channel_max))
        channels[-1] = self.img_channels

        # Construct layers.
        self.input = SynthesisInput(
            w_dim=self.w_dim, channels=int(channels[0]), size=int(sizes[0]),
            sampling_rate=sampling_rates[0], bandwidth=cutoffs[0])
        self.layer_names = []
        for idx in range(self.num_layers + 1):
            prev = max(idx - 1, 0)
            is_torgb = (idx == self.num_layers)
            is_critically_sampled = (idx >= self.num_layers - self.num_critical)
            # FP16 for layers whose sampling rate exceeds img_resolution / 2**num_fp16_res.
            use_fp16 = (sampling_rates[idx] * (2 ** self.num_fp16_res) > self.img_resolution)
            # use_fp16 = False
            layer = SynthesisLayer(
                w_dim=self.w_dim, is_torgb=is_torgb, is_critically_sampled=is_critically_sampled, use_fp16=use_fp16,
                in_channels=int(channels[prev]), out_channels= int(channels[idx]),
                in_size=int(sizes[prev]), out_size=int(sizes[idx]),
                in_sampling_rate=int(sampling_rates[prev]), out_sampling_rate=int(sampling_rates[idx]),
                in_cutoff=cutoffs[prev], out_cutoff=cutoffs[idx],
                in_half_width=half_widths[prev], out_half_width=half_widths[idx],
                **layer_kwargs)
            # The attribute name (e.g. 'L0_36_512') also determines the checkpoint key prefix.
            name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}'
            setattr(self, name, layer)
            self.layer_names.append(name)

    def forward(self, ws, **layer_kwargs):
        """Synthesize images from ``ws`` and return them as float32.

        ``ws`` is unbound along dim 1, so it is presumably
        [batch, num_ws, w_dim] (TODO confirm). ``layer_kwargs`` are forwarded to
        every SynthesisLayer.
        """
        ws = ws.to(torch.float32).unbind(dim=1)

        # Execute layers.
        x = self.input(ws[0])
        for name, w in zip(self.layer_names, ws[1:]):
            x = getattr(self, name)(x, w, **layer_kwargs)
        if self.output_scale != 1:
            x = x * self.output_scale

        # Ensure correct shape and dtype.
        x = x.to(torch.float32)
        return x

    def export_ncnn(self, ncnn_data, bottom_names, **layer_kwargs):
        """Emit the ncnn synthesis graph; ``bottom_names`` is ``[ws0, ws1, mixing]``.

        Each layer receives its latent index ``ws_idx`` in order: 0 for the input
        layer, then 1, 2, ... for the synthesis layers. Returns the output blob
        names, scaled by ``output_scale`` when it is not 1.
        """
        ws0, ws1, mixing = bottom_names

        # Execute layers.
        ws_idx = 0
        x = self.input.export_ncnn(ncnn_data, [ws0, ws1, mixing], ws_idx)
        ws_idx += 1
        for name in self.layer_names:
            layer = getattr(self, name)
            x = layer.export_ncnn(ncnn_data, [x[0], ws0, ws1, mixing], ws_idx, **layer_kwargs)
            ws_idx += 1
        if self.output_scale != 1:
            x = ncnn_utils.MulConstant(ncnn_data, x, scale=self.output_scale)

        # Post-processing layer, currently disabled:
        # img = ncnn_utils.StyleganPost(ncnn_data, img)
        return x

